package org.example.utils

import com.alibaba.fastjson.JSONObject
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.example.client.DbClient
import org.example.constant.ApolloConst
import org.example.dao.MysqlConfig
import scalikejdbc.{NamedDB, SQL}

import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet}
import java.util
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/**
 * MySQL utilities: connections, queries, ResultSet to Spark conversion, table metadata lookup and
 * JSON-to-table writes.
 */
object MysqlUtil {

  private val url = ApolloConst.jgdMysqlURL
  private val driver = ApolloConst.jgdMysqlDriver
  private val userName = ApolloConst.jgdMysqlUserName
  private val passWord = ApolloConst.jgdMysqlPassWord


  /**
   * Opens a new JDBC connection using the Apollo-configured MySQL url, driver and credentials.
   *
   * The caller owns the returned connection and must close it (e.g. via [[closeMysqlConn]]).
   *
   * @return a freshly opened MySQL connection
   */
  def getMysqlConn(): Connection = {
    Class.forName(driver)
    DriverManager.getConnection(url, userName, passWord)
  }

  /**
   * Closes a connection obtained from [[getMysqlConn]]. A `null` connection is ignored.
   *
   * @param conn the connection to close, may be null
   */
  def closeMysqlConn(conn: Connection): Unit = {
    if (conn != null) conn.close()
  }

  /**
   * Runs `sql` on a new connection and returns its live result set.
   *
   * NOTE: on success, neither the statement nor the connection behind the returned ResultSet is
   * ever closed by this method, and closing the ResultSet alone does not release them. Prefer
   * [[withMysqlQueryResult]], which releases everything once the handler returns.
   *
   * @param sql the query to execute
   * @return the (still open) query result
   */
  def getMysqlQueryResult(sql: String): ResultSet = {
    val connection = getMysqlConn()
    try {
      connection.createStatement().executeQuery(sql)
    } catch {
      case scala.util.control.NonFatal(e) =>
        // The caller never receives the connection when the query fails, so release it here.
        closeMysqlConn(connection)
        throw e
    }
  }

  /**
   * Runs `sql` and passes its result set to `handler`. Afterwards it closes the result set, the
   * statement and the connection. They are also closed when the query or the handler throws.
   *
   * @param sql     the query to execute
   * @param handler consumes the result set; it must not let the result set escape
   * @tparam T the handler's result type
   * @return whatever `handler` returns
   */
  def withMysqlQueryResult[T](sql: String)(handler: ResultSet => T): T = {
    val connection = getMysqlConn()
    try {
      val statement = connection.createStatement()
      try {
        val resultSet = statement.executeQuery(sql)
        try handler(resultSet) finally resultSet.close()
      } finally statement.close()
    } finally closeMysqlConn(connection)
  }

  /**
   * Reads every row from the result set's current position to its end into memory. Each row
   * carries the schema derived from the result set metadata. The result set is not closed.
   *
   * @param resultSet an open result set
   * @return an iterator over the buffered rows
   */
  def getMysqlQueryRow(resultSet: ResultSet): util.Iterator[GenericRowWithSchema] = {
    val (schema, rows) = readResultSet(resultSet)
    val buffered = new util.ArrayList[GenericRowWithSchema](rows.size)
    rows.foreach(values => buffered.add(new GenericRowWithSchema(values, schema)))
    buffered.iterator()
  }

  /**
   * Converts the remaining rows of a result set into a DataFrame. The data is buffered on the
   * driver. The result set is not closed.
   *
   * @param resultSet    an open result set
   * @param sparkSession session used to create the DataFrame
   * @return the converted DataFrame
   */
  def resultSetToDataframe(resultSet: ResultSet, sparkSession: SparkSession): DataFrame = {
    val (schema, rows) = readResultSet(resultSet)
    val sparkRows = new util.ArrayList[Row](rows.size)
    rows.foreach(values => sparkRows.add(Row.fromSeq(values.toSeq)))
    sparkSession.createDataFrame(sparkRows, schema)
  }

  /**
   * Derives the Spark schema of `resultSet` and reads all of its remaining rows. Each non-null
   * value is converted to the JVM type Spark expects for its column. The result set is not closed.
   *
   * @param resultSet an open result set
   * @return the schema and the converted row values, in column order
   */
  private def readResultSet(resultSet: ResultSet): (StructType, ArrayBuffer[Array[Any]]) = {
    val meta = resultSet.getMetaData
    val columns = (1 to meta.getColumnCount).map { i =>
      // The label is the alias when one is given, so `select a as b` produces a field named `b`.
      createColumnMapping(meta.getColumnLabel(i), meta.getColumnClassName(i), meta.getPrecision(i), meta.getScale(i))
    }
    val schema = StructType(columns.map(_._1))
    val converters = columns.map(_._2).toArray
    val rows = new ArrayBuffer[Array[Any]]()
    while (resultSet.next()) {
      rows += Array.tabulate[Any](converters.length) { idx =>
        val raw = resultSet.getObject(idx + 1)
        if (raw == null) null else converters(idx)(raw)
      }
    }
    (schema, rows)
  }

  /**
   * Maps one JDBC column to a Spark field plus a converter for its non-null values.
   *
   * `className` is the value of `ResultSetMetaData.getColumnClassName`, i.e. the class that
   * `ResultSet.getObject` returns for the column. Values of classes Spark cannot store directly
   * (`java.sql.Time`, `java.time.*`, `BigInteger`) are converted. Unrecognised classes fall back
   * to a string column holding `value.toString`, instead of failing.
   *
   * @param name      field name
   * @param className JDBC class name of the column values
   * @param precision JDBC precision, used for decimal columns
   * @param scale     JDBC scale, used for decimal columns
   * @return the field and the converter to apply to every non-null value of the column
   */
  private def createColumnMapping(name: String, className: String, precision: Int, scale: Int): (StructField, AnyRef => Any) = {
    val asIs: AnyRef => Any = value => value
    val asString: AnyRef => Any = value => value.toString
    val toTimestamp: AnyRef => Any = {
      case time: java.sql.Time => new java.sql.Timestamp(time.getTime)
      case dateTime: java.time.LocalDateTime => java.sql.Timestamp.valueOf(dateTime)
      case other => other
    }
    val toDate: AnyRef => Any = {
      case date: java.time.LocalDate => java.sql.Date.valueOf(date)
      case other => other
    }
    val toDecimal: AnyRef => Any = {
      case bigInt: java.math.BigInteger => new java.math.BigDecimal(bigInt)
      case other => other
    }
    val (dataType, convert): (DataType, AnyRef => Any) = className match {
      case "java.lang.String" => (StringType, asIs)
      case "java.lang.Integer" => (IntegerType, asIs)
      case "java.lang.Long" => (LongType, asIs)
      case "java.lang.Short" => (ShortType, asIs)
      case "java.lang.Byte" => (ByteType, asIs)
      case "java.lang.Boolean" => (BooleanType, asIs)
      case "java.lang.Double" => (DoubleType, asIs)
      case "java.lang.Float" => (FloatType, asIs)
      case "[B" => (BinaryType, asIs)
      case "java.sql.Date" => (DateType, asIs)
      case "java.time.LocalDate" => (DateType, toDate)
      case "java.sql.Timestamp" => (TimestampType, asIs)
      case "java.sql.Time" | "java.time.LocalDateTime" => (TimestampType, toTimestamp)
      case "java.math.BigDecimal" => (decimalTypeFor(precision, scale), asIs)
      case "java.math.BigInteger" => (decimalTypeFor(precision, 0), toDecimal)
      case _ => (StringType, asString)
    }
    (StructField(name, dataType), convert)
  }

  /**
   * Builds a Spark decimal type from JDBC precision/scale instead of a fixed (10, 0), which would
   * drop fractional digits. The precision is capped at Spark's maximum. When the driver reports
   * no precision, (MAX_PRECISION, 18) is used.
   */
  private def decimalTypeFor(precision: Int, scale: Int): DecimalType =
    if (precision <= 0) {
      DecimalType(DecimalType.MAX_PRECISION, 18)
    } else {
      val boundedPrecision = math.min(precision, DecimalType.MAX_PRECISION)
      DecimalType(boundedPrecision, math.min(math.max(scale, 0), boundedPrecision))
    }

  /**
   * Writes a DataFrame to a MySQL table through Spark's JDBC data source. It uses the same
   * configured url, credentials and driver class as [[getMysqlConn]].
   *
   * @param dataFrame data to write
   * @param dbTable   target table name
   * @param mode      Spark save mode name (e.g. "append", "overwrite")
   */
  def saveDataToMysql(dataFrame: DataFrame, dbTable: String, mode: String): Unit = {
    dataFrame
      .write
      .format("jdbc")
      .option("url", url)
      .option("dbtable", dbTable)
      .option("user", userName)
      .option("password", passWord)
      .option("driver", driver)
      .mode(mode)
      .save()
  }

  /**
   * Reads the result of `sql` into a DataFrame, using the connection settings in `mysqlConfig`.
   * The driver class is always `com.mysql.jdbc.Driver`.
   *
   * @param sparkSession SparkSession
   * @param mysqlConfig  MySQL connection settings (format, url, user, password)
   * @param sql          the query to execute
   * @return the query result
   */
  def readDataFromMysql(sparkSession: SparkSession, mysqlConfig: MysqlConfig, sql: String): DataFrame = {
    sparkSession.read
      .format(mysqlConfig.format)
      .option("url", mysqlConfig.url)
      .option("user", mysqlConfig.user)
      .option("password", mysqlConfig.password)
      .option("driver", "com.mysql.jdbc.Driver")
      .option("query", sql)
      .load()
  }

  /**
   * Looks up the column names of several tables. It (re)initialises the "mysql" pool from the
   * Apollo configuration first.
   *
   * @param database   MySQL database name
   * @param tableNames table names
   * @return table name -> column names in ordinal order. The map is empty when `database` is
   *         blank or no tables are given.
   */
  def getColsByTable(database: String, tableNames: ArrayBuffer[String]): mutable.HashMap[String, ArrayBuffer[String]] = {
    DbClient.init("mysql", ApolloConst.jgdMysqlDriver, ApolloConst.jgdMysqlURL, ApolloConst.jgdMysqlUserName, ApolloConst.jgdMysqlPassWord)
    getColsByTable(database, tableNames, "mysql")
  }

  /**
   * Looks up the column names of several tables using the connection pool `dbName`.
   *
   * @param database   MySQL database name
   * @param tableNames table names
   * @param dbName     connection pool name
   * @return table name -> column names in ordinal order. The map is empty when `database` is
   *         blank or no tables are given.
   */
  def getColsByTable(database: String, tableNames: ArrayBuffer[String], dbName: String): mutable.HashMap[String, ArrayBuffer[String]] = {
    val tableColMap = mutable.HashMap.empty[String, ArrayBuffer[String]]
    DbClient.usingDB(dbName) { db: NamedDB =>
      val connection: Connection = db.conn
      val metaData: DatabaseMetaData = connection.getMetaData
      if (database != null && database.trim.nonEmpty && tableNames != null && tableNames.nonEmpty) {
        tableNames.foreach { tableName =>
          tableColMap += (tableName -> listColumns(metaData, database, tableName))
        }
      }
      connection.close()
    }
    tableColMap
  }

  /**
   * Looks up the column names of one table using the connection pool `dbName`.
   *
   * @param database  MySQL database name
   * @param tableName table name
   * @param dbName    connection pool name
   * @return the column names in ordinal order
   */
  def getColsByTable(database: String, tableName: String, dbName: String): ArrayBuffer[String] = {
    val cols = new ArrayBuffer[String]()
    DbClient.usingDB(dbName) { db: NamedDB =>
      val connection: Connection = db.conn
      cols ++= listColumns(connection.getMetaData, database, tableName)
      connection.close()
    }
    cols
  }

  /**
   * Lists the column names of `database`.`tableName` and closes the metadata cursor afterwards.
   *
   * NOTE(review): `tableName` is a JDBC name pattern, so `_` and `%` inside it act as wildcards.
   */
  private def listColumns(metaData: DatabaseMetaData, database: String, tableName: String): ArrayBuffer[String] = {
    // MySQL Connector/J treats a database as a JDBC catalog by default and then ignores
    // schemaPattern. A null catalog would not restrict the lookup to `database`, so the database
    // is passed in both positions.
    val rs: ResultSet = metaData.getColumns(database, database, tableName, "%")
    try {
      val cols = new ArrayBuffer[String]()
      while (rs.next()) {
        cols += rs.getString("column_name")
      }
      cols
    } finally {
      rs.close()
    }
  }

  /**
   * Builds a function that turns one JSON object into the values of `mysqlColumns`, in order.
   *
   * Keys are matched to columns ignoring case (locale-independent) and underscores. The key
   * lookup is built per object, so objects with different key sets are handled and no field is
   * lost just because the first object lacked it. A column with no matching key yields null. If
   * several keys of one object normalise to the same name, the first in the object's key order
   * wins.
   *
   * The returned function captures only local values, so it can be shipped to Spark executors.
   */
  private def rowProjector(mysqlColumns: ArrayBuffer[String]): JSONObject => Array[String] = {
    val normalize: String => String = _.replace("_", "").toLowerCase(util.Locale.ROOT)
    val normalizedColumns: Array[String] = mysqlColumns.map(normalize).toArray
    (json: JSONObject) => {
      val keyByNormalizedName = mutable.HashMap.empty[String, String]
      val keys = json.keySet().iterator()
      while (keys.hasNext) {
        val key = keys.next()
        keyByNormalizedName.getOrElseUpdate(normalize(key), key)
      }
      normalizedColumns.map(column => keyByNormalizedName.get(column).map(key => json.getString(key)).orNull)
    }
  }

  /**
   * Runs `<verb> into <tableName> (...) values(...)` once per row in an auto-commit session.
   *
   * Column names are back-quoted, with any embedded back-quote doubled. `tableName` is inserted
   * into the SQL verbatim, so it must come from trusted code, never from user input.
   */
  private def writeRows(verb: String, rows: Array[Array[String]], mysqlColumns: ArrayBuffer[String], tableName: String, dbName: String): Unit = {
    val columnList = mysqlColumns.map(column => "`" + column.replace("`", "``") + "`").mkString(",")
    val placeholders = mysqlColumns.map(_ => "?").mkString(",")
    val sql = s"$verb into $tableName ($columnList) values($placeholders)"
    DbClient.usingDB(dbName) { db =>
      db autoCommit { implicit session =>
        rows.foreach { row =>
          SQL(sql).bind(row: _*).update().apply()
        }
      }
    }
  }

  /**
   * `replace into` a MySQL table. Keys of each JSON object are matched to the table columns
   * ignoring case and underscores. The RDD is collected to the driver before writing.
   *
   * @param jsons        rows to write
   * @param mysqlColumns target column names
   * @param tableName    target table name
   * @param dbName       connection pool name
   */
  def replaceIntoMysql(jsons: RDD[JSONObject], mysqlColumns: ArrayBuffer[String], tableName: String, dbName: String): Unit = {
    if (jsons != null && mysqlColumns != null && mysqlColumns.nonEmpty && !jsons.isEmpty()) {
      val project = rowProjector(mysqlColumns)
      writeRows("replace", jsons.map(project).collect(), mysqlColumns, tableName, dbName)
    }
  }

  /**
   * `insert into` a MySQL table. Keys of each JSON object are matched to the table columns
   * ignoring case and underscores. The RDD is collected to the driver before writing.
   *
   * @param jsons        rows to write
   * @param mysqlColumns target column names
   * @param tableName    target table name
   * @param dbName       connection pool name
   */
  def insertIntoMysql(jsons: RDD[JSONObject], mysqlColumns: ArrayBuffer[String], tableName: String, dbName: String): Unit = {
    if (jsons != null && mysqlColumns != null && mysqlColumns.nonEmpty && !jsons.isEmpty()) {
      val project = rowProjector(mysqlColumns)
      writeRows("insert", jsons.map(project).collect(), mysqlColumns, tableName, dbName)
    }
  }

  /**
   * `replace into` a MySQL table. Keys of each JSON object are matched to the table columns
   * ignoring case and underscores.
   *
   * @param jsons        rows to write
   * @param mysqlColumns target column names
   * @param tableName    target table name
   * @param dbName       connection pool name
   */
  def replaceIntoMysql(jsons: Array[JSONObject], mysqlColumns: ArrayBuffer[String], tableName: String, dbName: String): Unit = {
    if (jsons != null && mysqlColumns != null && jsons.nonEmpty && mysqlColumns.nonEmpty) {
      writeRows("replace", jsons.map(rowProjector(mysqlColumns)), mysqlColumns, tableName, dbName)
    }
  }

  /**
   * `insert into` a MySQL table. Keys of each JSON object are matched to the table columns
   * ignoring case and underscores.
   *
   * @param jsons        rows to write
   * @param mysqlColumns target column names
   * @param tableName    target table name
   * @param dbName       connection pool name
   */
  def insertIntoMysql(jsons: Array[JSONObject], mysqlColumns: ArrayBuffer[String], tableName: String, dbName: String): Unit = {
    if (jsons != null && mysqlColumns != null && jsons.nonEmpty && mysqlColumns.nonEmpty) {
      writeRows("insert", jsons.map(rowProjector(mysqlColumns)), mysqlColumns, tableName, dbName)
    }
  }

  def main(args: Array[String]): Unit = {

    val sql =
      """
        |select * from hzcp_itms.base_into_vehicle_info
      """.stripMargin
    // Closes the result set, statement and connection once printing is done.
    withMysqlQueryResult(sql) { set =>
      println("======" + set)
    }

  }


}
